{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Weighted Least Squares"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import statsmodels.api as sm\n",
    "from scipy import stats\n",
    "from statsmodels.iolib.table import SimpleTable, default_txt_fmt\n",
    "\n",
    "np.random.seed(1024)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## WLS Estimation\n",
    "\n",
    "### Artificial data: Heteroscedasticity with 2 groups\n",
    "\n",
    "Model assumptions:\n",
    "\n",
    " * Misspecification: the true model is quadratic, but only a linear model is estimated\n",
    " * Independent noise/error term\n",
    " * Two groups for the error variance: a low-variance group and a high-variance group"
   ]
  },
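  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the setup, the data-generating process implemented in the next cell can be written out as follows. The symbols $\\sigma$ (for `sig`), $w_i$ (for `w`) and $e_i$ (for `e`) are notation introduced here to mirror the code:\n",
    "\n",
    "$$\n",
    "y_i = 5 + 0.5\\,x_i - 0.01\\,(x_i - 5)^2 + \\sigma\\,w_i\\,e_i,\n",
    "\\qquad e_i \\sim N(0, 1), \\quad \\sigma = 0.5,\n",
    "$$\n",
    "\n",
    "where $w_i = 1$ for the first 60% of the observations and $w_i = 3$ for the rest. The fitted model keeps only the constant and $x_i$, so the quadratic term is left out of the regression."
   ]
  },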
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nsample = 50\n",
    "x = np.linspace(0, 20, nsample)\n",
    "X = np.column_stack((x, (x - 5) ** 2))\n",
    "X = sm.add_constant(X)\n",
    "beta = [5.0, 0.5, -0.01]\n",
    "sig = 0.5\n",
    "w = np.ones(nsample)\n",
    "w[nsample * 6 // 10 :] = 3\n",
    "y_true = np.dot(X, beta)\n",
    "e = np.random.normal(size=nsample)\n",
    "y = y_true + sig * w * e\n",
    "X = X[:, [0, 1]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### WLS knowing the true variance ratio of heteroscedasticity\n",
    "\n",
    "In this example, `w` is proportional to the standard deviation of the error (the standard deviation itself is `sig * w`). `WLS` requires weights that are proportional to the inverse of the error variance, so `1.0 / w**2` is passed as `weights`."
   ]
  },
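  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the weighted least squares estimator can be sketched in standard matrix notation. The weight matrix $W$ is notation introduced here and is not a name used in the notebook:\n",
    "\n",
    "$$\n",
    "\\hat{\\beta}_{\\mathrm{WLS}} = (X^\\top W X)^{-1} X^\\top W y,\n",
    "\\qquad W = \\operatorname{diag}\\left(1/w_1^2, \\dots, 1/w_n^2\\right).\n",
    "$$\n",
    "\n",
    "Equivalently, this is OLS applied after dividing each row of $X$ and $y$ by $w_i$, which gives the rescaled noise term a constant variance. Only the relative size of the weights matters for the coefficient estimates: multiplying all weights by the same positive constant leaves $\\hat{\\beta}_{\\mathrm{WLS}}$ unchanged."
   ]
  },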
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod_wls = sm.WLS(y, X, weights=1.0 / (w**2))\n",
    "res_wls = mod_wls.fit()\n",
    "print(res_wls.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OLS vs. WLS\n",
    "\n",
    "Estimate an OLS model for comparison:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_ols = sm.OLS(y, X).fit()\n",
    "print(res_ols.params)\n",
    "print(res_wls.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare the WLS standard errors to heteroscedasticity-corrected OLS standard errors:"
   ]
  },
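  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The heteroscedasticity-corrected standard errors are square roots of the diagonal of a sandwich covariance estimator. As a sketch, in notation introduced here ($\\hat{e}_i$ for the OLS residuals, $h_{ii}$ for the diagonal of the hat matrix $X (X^\\top X)^{-1} X^\\top$, $n$ for the number of observations and $k$ for the number of estimated parameters):\n",
    "\n",
    "$$\n",
    "\\widehat{\\mathrm{Cov}}(\\hat{\\beta}) = (X^\\top X)^{-1} X^\\top \\operatorname{diag}(\\omega_1, \\dots, \\omega_n)\\, X\\, (X^\\top X)^{-1},\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{HC0:}\\quad & \\omega_i = \\hat{e}_i^2 \\\\\n",
    "\\text{HC1:}\\quad & \\omega_i = \\frac{n}{n - k}\\,\\hat{e}_i^2 \\\\\n",
    "\\text{HC2:}\\quad & \\omega_i = \\frac{\\hat{e}_i^2}{1 - h_{ii}} \\\\\n",
    "\\text{HC3:}\\quad & \\omega_i = \\frac{\\hat{e}_i^2}{(1 - h_{ii})^2}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The four variants differ only in how the squared residuals are rescaled, so their standard errors are usually close to one another in moderately large samples."
   ]
  },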
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "se = np.vstack(\n",
    "    [\n",
    "        [res_wls.bse],\n",
    "        [res_ols.bse],\n",
    "        [res_ols.HC0_se],\n",
    "        [res_ols.HC1_se],\n",
    "        [res_ols.HC2_se],\n",
    "        [res_ols.HC3_se],\n",
    "    ]\n",
    ")\n",
    "se = np.round(se, 4)\n",
    "colnames = [\"const\", \"x1\"]\n",
    "rownames = [\"WLS\", \"OLS\", \"OLS_HC0\", \"OLS_HC1\", \"OLS_HC2\", \"OLS_HC3\"]\n",
    "tabl = SimpleTable(se, colnames, rownames, txt_fmt=default_txt_fmt)\n",
    "print(tabl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate OLS prediction interval:"
   ]
  },
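  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantities computed in the next cell are the ingredients of the usual 95% prediction interval for a new observation at row $x_i$ of $X$. As a sketch, with $\\hat{\\sigma}^2$, $n$ and $k$ as notation introduced here:\n",
    "\n",
    "$$\n",
    "\\hat{y}_i \\pm t_{0.975,\\,n-k}\\,\\sqrt{\\hat{\\sigma}^2 + x_i^\\top\\, \\widehat{\\mathrm{Cov}}(\\hat{\\beta})\\, x_i}\n",
    "$$\n",
    "\n",
    "Here $\\hat{\\sigma}^2$ corresponds to `res_ols.mse_resid`, $\\widehat{\\mathrm{Cov}}(\\hat{\\beta})$ to `res_ols.cov_params()`, and $n - k$ to `res_ols.df_resid`. The first term under the square root accounts for the noise in a new observation, the second for the uncertainty in the estimated coefficients."
   ]
  },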
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
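    "# Covariance matrix of the OLS coefficient estimates\n",
    "covb = res_ols.cov_params()\n",
    "# Residual variance plus x_i' covb x_i for each row x_i of X\n",
    "prediction_var = res_ols.mse_resid + (X * np.dot(covb, X.T).T).sum(1)\n",
    "prediction_std = np.sqrt(prediction_var)\n",
    "# 97.5% quantile of Student's t, for a two-sided 95% interval\n",
    "tppf = stats.t.ppf(0.975, res_ols.df_resid)"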
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_ols = res_ols.get_prediction()\n",
    "iv_l_ols = pred_ols.summary_frame()[\"obs_ci_lower\"]\n",
    "iv_u_ols = pred_ols.summary_frame()[\"obs_ci_upper\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Draw a plot to compare predicted values in WLS and OLS:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_wls = res_wls.get_prediction()\n",
    "iv_l = pred_wls.summary_frame()[\"obs_ci_lower\"]\n",
    "iv_u = pred_wls.summary_frame()[\"obs_ci_upper\"]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.plot(x, y, \"o\", label=\"Data\")\n",
    "ax.plot(x, y_true, \"b-\", label=\"True\")\n",
    "# OLS\n",
    "ax.plot(x, res_ols.fittedvalues, \"r--\")\n",
    "ax.plot(x, iv_u_ols, \"r--\", label=\"OLS\")\n",
    "ax.plot(x, iv_l_ols, \"r--\")\n",
    "# WLS\n",
    "ax.plot(x, res_wls.fittedvalues, \"g--.\")\n",
    "ax.plot(x, iv_u, \"g--\", label=\"WLS\")\n",
    "ax.plot(x, iv_l, \"g--\")\n",
    "ax.legend(loc=\"best\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feasible Weighted Least Squares (2-stage FWLS)\n",
    "\n",
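    "Like `w`, `w_est` is proportional to the standard deviation of the error, so it must be squared before it is inverted to form the `WLS` weights. Here the ratio of standard deviations is estimated from the OLS residuals of the two groups."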
   ]
  },
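  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the two-stage procedure in the next cell, let $G_1$ and $G_2$ denote the low- and high-variance groups, $n_g$ the group sizes, $\\hat{e}_i$ the first-stage OLS residuals and $\\bar{e}_g$ their group means (all notation introduced here). With $k = 2$ estimated parameters, the group variances and the estimated scale are\n",
    "\n",
    "$$\n",
    "s_g^2 = \\frac{1}{n_g - k} \\sum_{i \\in G_g} (\\hat{e}_i - \\bar{e}_g)^2,\n",
    "\\qquad\n",
    "\\hat{w}_i = \\begin{cases} 1, & i \\in G_1 \\\\ \\sqrt{s_2^2 / s_1^2}, & i \\in G_2 \\end{cases}\n",
    "$$\n",
    "\n",
    "and the second stage is a `WLS` fit with weights $1 / \\hat{w}_i^2$."
   ]
  },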
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resid1 = res_ols.resid[w == 1.0]\n",
    "var1 = resid1.var(ddof=int(res_ols.df_model) + 1)\n",
    "resid2 = res_ols.resid[w != 1.0]\n",
    "var2 = resid2.var(ddof=int(res_ols.df_model) + 1)\n",
    "w_est = w.copy()\n",
    "w_est[w != 1.0] = np.sqrt(var2) / np.sqrt(var1)\n",
    "res_fwls = sm.WLS(y, X, weights=1.0 / (w_est**2)).fit()\n",
    "print(res_fwls.summary())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
